import os
import re
import time
import json
import random
import openai
import tiktoken
from datetime import datetime
from openai import OpenAI
from tqdm import tqdm, trange
import numpy as np
from collections import Counter
from sklearn.metrics.pairwise import cosine_similarity, cosine_distances
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN
from scipy.spatial.distance import squareform

from utils.util import read_json, write_json, read_txt, is_json
from utils.token_count_decorator import token_count_decorator
from planning.src.protocol import Protocol
from planning.data_process.download import Downloader
from planning.data_process.metrics import (
    get_funcs_jaccard_sim, 
    get_levenshtein_distance,
    get_funcs_scibert_sim,
    get_embedding,
)

class Processer:
    """Maintenance helpers for the protocol JSON files under ``dataset_path``.

    Constructing an instance loads every file in ``dataset_path`` into memory.
    The methods then export metadata, enrich the per-protocol files, or build
    similarity matrices / embeddings that are saved under ``data_path``.
    """

    def __init__(self):
        self.data_path = "planning/data/"
        self.dataset_path = "dataset/bioprot2/"
        self.downloader = Downloader()
        # filename -> protocol dict, ordered by filename.
        self.data = self.load_data()
        # protocol["id"] -> protocol dict, in the same order as ``self.data``.
        self.id_to_protocol = {protocol["id"]: protocol for protocol in self.data.values()}

    def load_data(self):
        """Read every file in ``dataset_path`` and return them keyed by filename (sorted)."""
        protocols = {}
        for filename in os.listdir(self.dataset_path):
            file_path = os.path.join(self.dataset_path, filename)
            protocols[filename] = read_json(file_path)
        return dict(sorted(protocols.items()))

    def extract_title_description(self):
        """Write ``protocols_metadata.json`` with id/title/description per protocol.

        Protocols without an ``ai_generated_description`` are not exported;
        their ``original description`` is printed instead.
        """
        protocols_metadata = []
        for filename, protocol in self.data.items():
            if "ai_generated_description" in protocol:
                protocols_metadata.append({
                    "id": int(filename.split(".json")[0]),
                    "title": protocol["title"],
                    "description": protocol["ai_generated_description"]
                })
            else:
                print(protocol["original description"])

        print(len(protocols_metadata))
        write_json(self.data_path + "protocols_metadata.json", protocols_metadata)

    def supplement_from_metadata(self, start=80, stop=100):
        """Add ``publish_time`` and ``doi`` to a slice of the dataset files.

        Args:
            start: Index (in filename order) of the first file to process.
            stop: Index one past the last file to process.

        Files whose stored title differs from the downloaded title are
        reported and left untouched.
        """
        filenames = list(self.data.keys())[start:stop]
        for filename in tqdm(filenames):
            protocol = self.data[filename]
            title = protocol["title"]
            protocol_id = int(filename.split(".json")[0])
            response = self.downloader.get_protocol(id=protocol_id)
            metadata = response["payload"]

            original_title = metadata["title"]
            if original_title != title:
                print("Not matching!", filename)
                print("title in bioprot: ", title)
                print("protocol searched: ", original_title)
                continue

            protocol = {
                **protocol,
                "publish_time": metadata.get("published_on", ""),
                "doi": metadata.get("doi", "")
            }

            write_json(os.path.join(self.dataset_path, filename), protocol)

    def merge_metadata(self):
        """Rebuild ``protocols_metadata.json`` including category and publish time.

        Categories are looked up by title in
        ``protocols_metadata_classified.json``; a title missing from that file
        raises ``KeyError``.
        """
        protocols_metadata = []

        old_metadata = read_json(self.data_path + "protocols_metadata_classified.json")
        title_to_category = {metadata["title"]: metadata["classification"] for metadata in old_metadata}

        for filename, protocol in self.data.items():
            if "ai_generated_description" in protocol:
                # ``publish_time`` is passed to datetime.fromtimestamp, i.e. it
                # is interpreted as a POSIX timestamp in local time.
                pt = datetime.fromtimestamp(protocol["publish_time"])
                protocols_metadata.append({
                    "id": int(filename.split(".json")[0]),
                    "title": protocol["title"],
                    "description": protocol["ai_generated_description"],
                    "category": title_to_category[protocol["title"]],
                    "publish_time": pt.strftime("%Y-%m-%d %H:%M:%S")
                })

        print(len(protocols_metadata))
        write_json(self.data_path + "protocols_metadata.json", protocols_metadata)

    def merge_origin(self):
        """Copy each ``category`` from ``protocols_metadata.json`` into its dataset file."""
        metadata_list = read_json(self.data_path + "protocols_metadata.json")
        for metadata in metadata_list:
            protocol_id = str(int(metadata["id"]))
            # Read and write the same file, both resolved from dataset_path.
            origin_path = f"{self.dataset_path}{protocol_id}.json"
            origin = read_json(origin_path)
            origin["category"] = metadata["category"]
            write_json(origin_path, origin)

    def protocol_sim_matrix(self):
        """Return the min-max normalised pairwise pseudocode similarity matrix.

        Entry ``[i][j]`` is ``protocol_similarity`` of protocol ``i`` against
        protocol ``j``; rows/columns follow the order of the id-keyed
        pseudocode dict. If every score is identical the result is all zeros
        (instead of dividing by zero).
        """
        # Alternative source (disabled): prefer edited pseudocode when present.
        # pseudocode_dict = {
        #     protocol["id"]: protocol["edited_pseudocode"] if protocol["edited_pseudocode"] else protocol["generated_pseudocode"]
        #     for _, protocol in self.data.items()
        # }
        pseudocode_dict = {
            protocol["id"]: protocol["generated_pseudocode"]
            for protocol in self.data.values()
        }
        pseudocode_list = list(pseudocode_dict.values())
        # Everything before the "# Protocol steps" marker of each pseudocode.
        pseudofunctions = [pseudocode.split("# Protocol steps")[0] for pseudocode in pseudocode_list]
        n = len(pseudocode_list)

        sim_matrix = np.zeros((n, n))
        # Not symmetric: sim_matrix[i][j] != sim_matrix[j][i]
        for i, pseudocode_1 in enumerate(tqdm(pseudocode_list)):
            for j, pseudocode_2 in enumerate(tqdm(pseudocode_list, leave=False)):
                sim_matrix[i][j] = self.protocol_similarity(pseudocode_1, pseudocode_2, pseudofunctions[j])

        min_val = np.min(sim_matrix)
        value_range = np.max(sim_matrix) - min_val
        if value_range == 0:
            # All scores equal: normalisation is undefined, report "no spread".
            return np.zeros_like(sim_matrix)

        return (sim_matrix - min_val) / value_range

    def protocol_similarity(self, pseudocode_1, pseudocode_2, pseudofunctions_2=''):
        """Score how similar two pseudocode strings are (higher is more similar).

        The score is the function-name Jaccard similarity, minus the
        Levenshtein distance scaled by the line count of ``pseudocode_1``,
        plus the SciBERT function similarity.
        """
        funcs_jaccard_sim = get_funcs_jaccard_sim(pseudocode_1, pseudocode_2)
        # argkeys_jaccard_sim = get_argkeys_jaccard_sim(pseudocode_1, pseudocode_2)
        levenshtein_distance = get_levenshtein_distance(pseudofunctions_2, pseudocode_1, pseudocode_2)
        scaled_levenshtein_distance = levenshtein_distance / len(pseudocode_1.split("\n"))
        funcs_scibert_sim = get_funcs_scibert_sim(pseudocode_1, pseudocode_2)

        return funcs_jaccard_sim - scaled_levenshtein_distance + funcs_scibert_sim

    def cluster_and_cut(self, sim_matrix, threshold):
        '''
        Cluster protocols by average-linkage on ``1 - similarity`` and cut the
        tree at ``threshold``.

        Args:
            sim_matrix: Square numpy array of pairwise similarities; row ``i``
                corresponds to the ``i``-th key of ``self.id_to_protocol``.
            threshold: Distance at which the dendrogram is cut.

        Returns:
            ``(label_count, label_to_ids)``: cluster label -> size (largest
            first) and cluster label -> list of protocol ids.
        '''
        # Symmetrise, then turn similarity into distance.
        distance_matrix = 1 - (sim_matrix + sim_matrix.T) / 2
        # linkage() interprets a 2-D array as observation vectors, so the
        # pairwise distances must be handed over in condensed form. The
        # diagonal (self-distance) is forced to zero first.
        np.fill_diagonal(distance_matrix, 0.0)
        condensed = squareform(distance_matrix, checks=False)
        Z = linkage(condensed, method='average')
        cluster_labels = fcluster(Z, t=threshold, criterion='distance')

        label_count = dict(sorted(Counter(cluster_labels).items(), key=lambda item: item[1], reverse=True))
        ids = list(self.id_to_protocol)
        label_to_ids = {}
        for i, label in enumerate(cluster_labels):
            label_to_ids.setdefault(label, []).append(ids[i])

        return label_count, label_to_ids

    def dump_description_embedding(self):
        """Embed every ``ai_generated_description`` and save them to ``desc_emb.npy``."""
        dataset = list(self.id_to_protocol.values())
        embeddings = np.array([
            get_embedding(data["ai_generated_description"]) for data in tqdm(dataset)
        ])
        np.save(self.data_path+"desc_emb.npy", embeddings)
        print(embeddings.shape)

    def dump_sim_matrix(self):
        """Save the cosine-similarity matrix of ``desc_emb.npy`` to ``desc_sim_matrix.npy``."""
        desc_emb = np.load(self.data_path+"desc_emb.npy")
        sim_matrix = cosine_similarity(desc_emb)
        np.save(self.data_path+"desc_sim_matrix.npy", sim_matrix)

    def add_idx_to_origin(self):
        """Store a running index (as a string) under ``idx`` in every dataset file."""
        count = 0
        for count, (filename, protocol) in enumerate(self.data.items(), start=1):
            protocol["idx"] = str(count - 1)
            write_json(self.dataset_path+filename, protocol)
        print(count)


class DataRefiner:
    """Clustering, matching and LLM-conversion utilities over ``Protocol`` objects.

    The dataset is loaded from ``dataset_path`` on construction and kept
    sorted; several methods rely on that order matching the row order of
    embedding files saved under ``planning/data/``.
    """

    def __init__(self) -> None:
        self.data_path = "planning/data/"
        self.dataset_path = "dataset/bioprot2/"
        self.dataset = self.load_dataset()
        self.id_list = [protocol.id for protocol in self.dataset]

    def load_dataset(self):
        """Load every file in ``dataset_path`` as a ``Protocol`` and return them sorted.

        Files for which ``Protocol.fromjson`` returns a falsy value are skipped.
        """
        dataset = sorted([
            protocol for filename in os.listdir(self.dataset_path)
            if (protocol := Protocol.fromjson(read_json(os.path.join(self.dataset_path, filename))))
        ])
        return dataset

    def dump_bioprot_emb_matrix(self, model="scibert"):
        """Embed every protocol description and save to ``desc_emb_{model}.npy``.

        Args:
            model: A name ending in ``bert`` selects ``get_embedding``; a name
                starting with ``text-embedding`` selects the OpenAI endpoint.

        Raises:
            ValueError: If ``model`` matches neither family.
        """
        if model.endswith("bert"):
            embeddings = [get_embedding(protocol.description) for protocol in tqdm(self.dataset)]
        elif model.startswith("text-embedding"):
            embeddings = [get_openai_embedding(protocol.description, model=model) for protocol in tqdm(self.dataset)]
        else:
            raise ValueError(f"Unsupported embedding model: {model!r}")
        np.save(self.data_path+f"desc_emb_{model}.npy", embeddings)

    def cluster_and_cut(self, embed_matrix, threshold):
        """Average-linkage clustering on cosine distance, cut at ``threshold``."""
        # squareform converts the square distance matrix to the condensed
        # form that linkage() expects.
        distance_matrix = squareform(cosine_distances(embed_matrix))
        Z = linkage(distance_matrix, method="average")
        labels = fcluster(Z, t=threshold, criterion='distance')
        return self.__cope_cluster_result(labels)

    def cluster_kmeans(self, embed_matrix, n_clusters=3):
        """K-means clustering (fixed ``random_state=42``) of the embedding rows."""
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        kmeans.fit(embed_matrix)
        return self.__cope_cluster_result(kmeans.labels_)

    def cluster_agglo(self, embed_matrix, n_clusters=3):
        """Agglomerative clustering with cosine distance and average linkage."""
        try:
            # scikit-learn >= 1.2 names the distance argument ``metric``.
            clustering = AgglomerativeClustering(n_clusters=n_clusters, metric='cosine', linkage='average')
        except TypeError:
            # Older scikit-learn only knows the former name ``affinity``.
            clustering = AgglomerativeClustering(n_clusters=n_clusters, affinity='cosine', linkage='average')
        clustering.fit(embed_matrix)
        return self.__cope_cluster_result(clustering.labels_)

    def cluster_dbscan(self, embed_matrix, eps=0.5, min_samples=3):
        """DBSCAN over the precomputed cosine-distance matrix."""
        distance_matrix = cosine_distances(embed_matrix)
        db = DBSCAN(eps=eps, min_samples=min_samples, metric="precomputed")
        db.fit(distance_matrix)
        return self.__cope_cluster_result(db.labels_)

    def __cope_cluster_result(self, labels):
        """Split cluster labels into "other" and "last 89" parts and summarise both.

        Returns:
            ``(tot_label_count, label_count, label_to_protocols)`` where the
            first counts labels of all rows except the last 89, the second
            counts labels of the last 89 rows, and the third groups
            ``self.dataset`` entries by the label of the matching tail row.
        """
        # NOTE(review): 89 is presumably the number of protocols in
        # self.dataset appended at the end of the embedding matrix — confirm.
        tail_size = 89
        tot_label_count = dict(sorted(Counter(labels[:-tail_size]).items(), key=lambda item: item[1], reverse=True))
        bioprot_labels = labels[-tail_size:]
        label_count = dict(sorted(Counter(bioprot_labels).items(), key=lambda item: item[1], reverse=True))
        label_to_protocols = {}
        for i, label in enumerate(bioprot_labels):
            label_to_protocols.setdefault(label, []).append(self.dataset[i])
        return tot_label_count, label_count, label_to_protocols

    def get_category_centroid(self):
        """Return category -> mean title embedding, using rows of ``title_emb.npy``.

        Row ``i`` of the embedding file is paired with ``self.dataset[i]``.
        """
        embeddings = np.load("planning/data/title_emb.npy")
        category_to_embeddings = {}
        for i, protocol in enumerate(self.dataset):
            category_to_embeddings.setdefault(protocol.category, []).append(embeddings[i])
        return {category: np.mean(vectors, axis=0) for category, vectors in category_to_embeddings.items()}

    def match_test(self):
        """Assign each sampled genetics embedding to its nearest category centroid.

        Prints the per-category assignment counts, the number of similarities
        collected per category, and the mean similarity per category.
        """
        category_centroids: dict = self.get_category_centroid()

        genetics_title_embs = np.load("planning/data/genetics_sampled_800_emb.npy")

        classification_results = []
        multi_sims = {}

        for emb in tqdm(genetics_title_embs):
            best_category = None
            best_similarity = -1  # Lower bound of cosine similarity.

            for category, centroid in category_centroids.items():
                similarity = cosine_similarity([emb], [centroid])[0][0]

                multi_sims.setdefault(category, []).append(similarity)

                # Keep the centroid with the highest similarity so far.
                if similarity > best_similarity:
                    best_similarity = similarity
                    best_category = category

            classification_results.append(best_category)

        count = dict(sorted(Counter(classification_results).items(), key=lambda item: item[1], reverse=True))
        print(count)

        print({category: len(sim_list) for category, sim_list in multi_sims.items()})

        avg_sims = {category: np.mean(sim_list, axis=0) for category, sim_list in multi_sims.items()}
        print(avg_sims)

    def match_test_2(self):
        """Write ``test.json``: per sampled area, mean similarity to each category centroid.

        For every area in ``name_mapping`` the categories are ordered from
        highest to lowest mean similarity; values are stored as strings.
        """
        category_centroids: dict = self.get_category_centroid()
        match_dict = {}
        for name in tqdm(name_mapping.values()):
            embeddings = np.load(f"planning/data/{name}_sampled_200_emb.npy")
            multi_sims = {}
            for emb in embeddings:
                for category, centroid in category_centroids.items():
                    similarity = cosine_similarity([emb], [centroid])[0][0]
                    multi_sims.setdefault(category, []).append(similarity)

            mean_sims = {category: np.mean(sim_list, axis=0) for category, sim_list in multi_sims.items()}
            # Rank on the numeric value; only stringify afterwards so the
            # order is not decided by lexicographic string comparison.
            ranked = sorted(mean_sims.items(), key=lambda item: item[1], reverse=True)
            match_dict[name] = {category: str(mean_sim) for category, mean_sim in ranked}

        write_json("test.json", match_dict)

    def filter_dataset(self, mode="avg"):
        """Score each BioProt description against the sampled genetics descriptions.

        Args:
            mode: ``"avg"`` returns the mean cosine similarity of each BioProt
                embedding to all genetics embeddings; ``"centroid"`` returns
                the cosine similarity to the genetics centroid.

        Returns:
            Array of scores derived from ``desc_emb_text-embedding-3-large.npy``
            (one per row for ``"avg"``; the squeezed similarity row for
            ``"centroid"``).

        Raises:
            ValueError: If ``mode`` is not one of the two supported values.
        """
        if mode not in ("avg", "centroid"):
            raise ValueError(f"Unsupported mode: {mode!r}")

        genetics_desc_embs = np.load("planning/data/genetics_sampled_800_emb_text-embedding-3-large.npy")
        bioprot_desc_embs = np.load("planning/data/desc_emb_text-embedding-3-large.npy")
        if mode == "avg":
            sim_matrix = cosine_similarity(bioprot_desc_embs, genetics_desc_embs)
            return np.mean(sim_matrix, axis=1)

        centroid = np.mean(genetics_desc_embs, axis=0)
        sim_matrix = cosine_similarity([centroid], bioprot_desc_embs)
        return sim_matrix.squeeze()

    def select_protocol(self, num=9):
        """Return the ``num`` protocols closest to the genetics centroid, best first."""
        avg_sims = self.filter_dataset(mode="centroid")
        sorted_indices = np.argsort(avg_sims)[::-1]
        return [self.dataset[i] for i in sorted_indices[:num]]

    def find_cases(self):
        """Return the pairwise cosine-similarity matrix of the BioProt description embeddings."""
        bioprot_desc_embs = np.load("planning/data/desc_emb_text-embedding-3-large.npy")
        sim_matrix = cosine_similarity(bioprot_desc_embs)
        return sim_matrix

    def get_shortest_protocol(self):
        """Return the protocol with the fewest numbered steps, requiring more than 5.

        Returns ``None`` when no protocol has more than 5 numbered steps.
        """
        min_steps = 100000
        example_protocol = None
        for protocol in self.dataset:
            steps = self.__convert_to_sentence_list(protocol.steps)
            if (length := len(steps)) < min_steps and length > 5:
                min_steps = length
                example_protocol = protocol
        return example_protocol

    def __convert_to_sentence_list(self, steps):
        """Return the non-empty lines of ``steps`` that start with ``<digits>.``."""
        if not steps:
            return []

        sentences = [sentence.strip() for sentence in steps.split("\n") if sentence.strip()]
        operation_steps = [sentence for sentence in sentences if re.match(r'^\d+\.', sentence)]
        return operation_steps

    def pseudocode_to_json(self):
        """Ask the chat model to convert each protocol's pseudocode to JSON.

        For every protocol up to 5 attempts are made. The first reply that
        contains a fenced ``json`` block holding valid JSON is stored under
        ``program`` in the protocol's dataset file. A message is printed when
        all attempts fail.
        """
        prompt = read_txt("planning/data/prompt/pseudocode_to_json.txt")
        json_block = re.compile(r'```json([^`]*)```', re.DOTALL)
        for protocol in tqdm(self.dataset):
            content = prompt.replace("{pseudocode}", protocol.pseudocode)
            for _ in range(5):
                response = self.__chatgpt_function(content=content)
                # A reply without text content is treated as "no JSON found".
                program = json_block.findall(response or "")
                if len(program) > 0 and is_json(plan := program[0].strip()):
                    path = self.dataset_path+f"{protocol.id}.json"
                    protocol_ini = read_json(path)
                    protocol_ini["program"] = json.loads(plan)
                    write_json(path, protocol_ini)
                    break
            else:
                print(f"No valid JSON program obtained for protocol {protocol.id}")


    @token_count_decorator(flow="together", batch=False)
    def __chatgpt_function(self, content, gpt_model="gpt-4o-mini"):
        """Send ``content`` as a single user message and return the reply text.

        ``openai.APIError`` is printed and the request retried indefinitely,
        waiting between attempts (1s doubling up to 60s) so a failing API is
        not hammered in a tight loop.
        """
        client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
        delay = 1.0
        while True:
            try:
                chat_completion = client.chat.completions.create(
                    messages=[
                        {"role": "user", "content": content}
                    ],
                    model=gpt_model
                )
                return chat_completion.choices[0].message.content
            except openai.APIError as error:
                print(error)
                time.sleep(delay)
                delay = min(delay * 2, 60.0)


class Summarizer:
    """Generate AI summaries for a corpus of protocols through the OpenAI Batch API.

    Protocols are read from ``planning/data/corpus/{domain}/`` and the
    annotated copies are written to ``planning/data/corpus/{domain}_summarized/``.
    """

    def __init__(self, domain):
        self.domain = domain
        self.data_path = f"planning/data/corpus/{domain}/"
        self.prompt = read_txt("planning/data/prompt/protocol_summary.txt")
        self.batch_input_path = f"planning/data/temp_batch/annotation_batch_input_{domain}.jsonl"
        self.batch_output_path = f"planning/data/temp_batch/annotation_batch_output_{domain}.jsonl"
        # Number of protocols submitted per batch job.
        self.concurrent = 250

    def load_data(self):
        """Yield the parsed JSON of every file in ``data_path``."""
        for filename in os.listdir(self.data_path):
            yield read_json(os.path.join(self.data_path, filename))

    def annotate(self):
        """Summarise every protocol that has no output file yet, in batches."""
        batch = []

        total_protocols = len(os.listdir(self.data_path))

        with tqdm(total=total_protocols, desc="Annotating protocols") as pbar:
            for protocol in self.load_data():
                protocol_id = str(protocol["id"])
                output_file = f"planning/data/corpus/{self.domain}_summarized/{protocol_id}.json"
                if os.path.exists(output_file):
                    # Already summarised on an earlier run.
                    pbar.update(1)
                    continue
                batch.append(protocol)
                if len(batch) >= self.concurrent:
                    self.process_batch(batch)
                    pbar.update(len(batch))
                    batch.clear()

            # Flush the final, partially filled batch.
            if batch:
                self.process_batch(batch)
                pbar.update(len(batch))

        print("Annotation is finished")

    def process_batch(self, batch):
        """Submit one batch, wait for it and write one output file per result.

        Args:
            batch: List of protocol dicts; each request's ``custom_id`` is the
                protocol's index in this list.
        """
        self.empty_jsonl_contents()

        for i, protocol in enumerate(batch):
            title = protocol["title"]
            steps = "".join(protocol["procedures"])
            self.gpt_batch_store(self.prompt.replace("{title}", title).replace("{protocol_steps}", steps), index=str(i))

        print("Batch stored")
        batch_obj = self.gpt_batch_call()
        print("Batch called, waiting for results...")

        results = self.gpt_batch_result(batch_obj.id)
        print("Results received")

        output_dir = f"planning/data/corpus/{self.domain}_summarized"
        os.makedirs(output_dir, exist_ok=True)
        for result in results:
            protocol = batch[int(result["custom_id"])]
            protocol["ai_generated_description"] = result["text"]
            protocol["ai_generated_description length in tokens"] = result["length"]
            protocol_id = str(protocol["id"])
            write_json(f"{output_dir}/{protocol_id}.json", protocol)

        time.sleep(8)

    def judge_process(self, reply, model="gpt-4o-mini"):
        """Return ``(reply, token_count)``, or ``(None, None)`` for an empty reply."""
        if not reply:
            return None, None
        encoding = tiktoken.encoding_for_model(model)
        tokens = encoding.encode(reply)
        return reply, len(tokens)

    # NOTE(review): used here without arguments, unlike the other call sites
    # of token_count_decorator in this file — confirm it supports both forms.
    @token_count_decorator
    def gpt_batch_store(self, content, index):
        """Append one chat-completion request line to the batch input file.

        Args:
            content: Text of the user message.
            index: ``custom_id`` of the request.
        """
        prompt_unit = {
            "custom_id": index,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-4o-mini",
                "messages": [
                    {"role": "system", "content": "You are an experimental scientist in the fields of biology."},
                    {"role": "user", "content": content}
                ],
                "max_tokens": 1000
            }
        }
        with open(self.batch_input_path, 'a') as file:
            file.write(json.dumps(prompt_unit) + '\n')

    def gpt_batch_call(self):
        """Upload the batch input file and create the batch job; return the batch object."""
        client = OpenAI()
        # Context manager so the upload handle is closed after the request.
        with open(self.batch_input_path, "rb") as input_file:
            batch_input_file = client.files.create(
                file=input_file,
                purpose="batch"
            )
        batch_obj = client.batches.create(
            input_file_id=batch_input_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )
        return batch_obj

    def gpt_batch_result(self, batch_id):
        """Poll the batch every 3 seconds until it finishes and return its results.

        The raw output is saved to ``batch_output_path``. Returns the list
        built by ``process_results``, or ``[]`` when the batch status is
        failed, expired, cancelled or cancelling.
        """
        client = OpenAI()
        while True:
            batch = client.batches.retrieve(batch_id)
            if batch.status == "completed":
                result = client.files.content(batch.output_file_id).content
                with open(self.batch_output_path, "wb") as file:
                    file.write(result)
                results = []
                with open(self.batch_output_path, "r") as file:
                    for line in file:
                        line = line.strip()
                        if line:  # tolerate blank lines in the JSONL output
                            results.append(json.loads(line))
                return self.process_results(results)
            elif batch.status in ["failed", "expired", "cancelled", "cancelling"]:
                print(f"Batch {batch.status}")
                return []
            else:
                time.sleep(3)

    def process_results(self, results):
        """Convert raw batch output records to ``custom_id``/``text``/``length`` dicts.

        Records whose reply is missing or empty get ``"NONE"`` for both
        ``text`` and ``length``.
        """
        return_results = []
        for res in results:
            try:
                reply = self.extract_reply(res)
            except (KeyError, IndexError, TypeError):
                # The record does not have the expected response structure.
                reply = None
            text = reply.strip() if isinstance(reply, str) else ""
            response, token_num = self.judge_process(text)
            if response:
                return_results.append({
                    "custom_id": res["custom_id"],
                    "text": response,
                    "length": token_num
                })
            else:
                return_results.append({
                    "custom_id": res["custom_id"],
                    "text": "NONE",
                    "length": "NONE"
                })
        return return_results

    @token_count_decorator(flow="output")
    def extract_reply(self, res):
        """Return the message content of the first choice in a batch output record."""
        return res["response"]["body"]["choices"][0]["message"]["content"]

    def empty_jsonl_contents(self):
        """Truncate the batch input file if it already exists."""
        if os.path.exists(self.batch_input_path):
            with open(self.batch_input_path, 'w') as file:
                file.write('')

# Maps a protocol "bigAreas" label to the file-name stem used for that area's
# sampled-embedding files (e.g. "{stem}_sampled_200_emb.npy").
name_mapping = {
    "Molecular Biology & Genetics": "molecular_biology_and_genetics",
    "Biomedical & Clinical Research": "biomedical_and_clinical_research",
    "Ecology & Environmental Biology": "ecology_and_environmental_biology",
    "Bioengineering & Technology": "bioengineering_and_technology",
}

def sample_original_protocols(sample_size=800):
    """Randomly sample genetics protocols into ``planning/data/genetics_sampled/``.

    A protocol qualifies when its ``bigAreas`` contains
    "Molecular Biology & Genetics" and does not contain
    "Bioinformatics & Computational Biology".

    Args:
        sample_size: Number of protocols to sample (default 800).

    Raises:
        ValueError: If fewer than ``sample_size`` protocols qualify.
    """
    all_protocols = []
    for file in os.listdir("dataset/original_protocol"):
        protocol = read_json("dataset/original_protocol/" + file)
        if "Molecular Biology & Genetics" in protocol["bigAreas"] and "Bioinformatics & Computational Biology" not in protocol["bigAreas"]:
            all_protocols.append(protocol)

    if len(all_protocols) < sample_size:
        raise ValueError(f"符合条件的协议数量不足，只有 {len(all_protocols)} 个。")

    for protocol in random.sample(all_protocols, sample_size):
        protocol_id = str(protocol["id"])
        write_json(f"planning/data/genetics_sampled/protocol_{protocol_id}.json", protocol)

def sample_original_protocols_2():
    """Randomly sample one fifth of the bioengineering protocols into ``planning/data/corpus/BioEng/``.

    A protocol qualifies when its ``bigAreas`` contains
    "Bioengineering & Technology" and does not contain
    "Bioinformatics & Computational Biology".

    Raises:
        ValueError: If fewer than five protocols qualify (sample size would be 0).
    """
    all_protocols = []

    # Read the original protocols and keep the qualifying ones.
    for file in os.listdir("dataset/original_protocol"):
        protocol = read_json("dataset/original_protocol/" + file)
        if "Bioengineering & Technology" in protocol["bigAreas"] and "Bioinformatics & Computational Biology" not in protocol["bigAreas"]:
            all_protocols.append(protocol)

    print("tot", len(all_protocols))

    # Sample one fifth of the total.
    # NOTE(review): a previous comment here said one tenth while the code
    # takes one fifth — confirm the intended ratio.
    sample_size = len(all_protocols) // 5
    print(sample_size)
    if sample_size == 0:
        raise ValueError("协议数量过少，无法采样。")

    # Write out the sampled protocols.
    for protocol in random.sample(all_protocols, sample_size):
        protocol_id = str(protocol["id"])
        write_json(f"planning/data/corpus/BioEng/{protocol_id}.json", protocol)

def sample_original_protocols_all():
    """Sample 200 titles per area in ``name_mapping`` and save their embeddings.

    A protocol counts for an area when its ``bigAreas`` contains that area
    and does not contain "Bioinformatics & Computational Biology". For each
    area the embeddings are saved to
    ``planning/data/{file_suffix}_sampled_200_emb.npy``.

    Raises:
        ValueError: If an area has fewer than 200 qualifying protocols.
    """
    excluded_area = "Bioinformatics & Computational Biology"

    # Read the corpus once and bucket titles per area, instead of re-reading
    # every file for each area.
    titles_by_area = {big_area: [] for big_area in name_mapping}
    for file in os.listdir("dataset/original_protocol"):
        protocol = read_json("dataset/original_protocol/" + file)
        big_areas = protocol["bigAreas"]
        if excluded_area in big_areas:
            continue
        for big_area, bucket in titles_by_area.items():
            if big_area in big_areas:
                bucket.append(protocol["title"])

    for big_area, file_suffix in name_mapping.items():
        all_titles = titles_by_area[big_area]

        # Make sure enough protocols qualify for this area.
        if len(all_titles) < 200:
            raise ValueError(f"{big_area} 符合条件的协议数量不足，只有 {len(all_titles)} 个。")

        # Randomly sample 200 of the qualifying titles.
        titles = random.sample(all_titles, 200)

        # Embed each title and save the result as a .npy file.
        embeddings = [get_embedding(title) for title in tqdm(titles)]
        np.save(f"planning/data/{file_suffix}_sampled_200_emb.npy", embeddings)
    
def dump_sampled_embeddings(model="text-embedding-3-large", domain="Ecology"):
    """Embed the descriptions of a candidate domain and save them as one array.

    Files in ``planning/data/candidate/{domain}/`` are processed in sorted
    filename order for every model, so row order is reproducible. Files whose
    parsed JSON is falsy are skipped. The result is saved to
    ``planning/data/{domain}_candidate.npy``.

    Args:
        model: A name ending in ``bert`` selects ``get_embedding``; a name
            starting with ``text-embedding`` selects the OpenAI endpoint.
        domain: Sub-directory name under ``planning/data/candidate/``.

    Raises:
        ValueError: If ``model`` matches neither family.
    """
    if model.endswith("bert"):
        def embed(text):
            return get_embedding(text)
    elif model.startswith("text-embedding"):
        def embed(text):
            return get_openai_embedding(text, model=model)
    else:
        raise ValueError(f"Unsupported embedding model: {model!r}")

    path = f"planning/data/candidate/{domain}/"
    embeddings = [
        embed(protocol["ai_generated_description"])
        for filename in tqdm(sorted(os.listdir(path)))
        if (protocol := read_json(os.path.join(path, filename)))
    ]
    np.save(f"planning/data/{domain}_candidate.npy", embeddings)

# @token_count_decorator(model="text-embedding-3-large", batch=False)
def get_openai_embedding(text, model="text-embedding-3-large"):
    """Return the OpenAI embedding vector for ``text``.

    Newlines in ``text`` are replaced by spaces before the request.
    ``openai.APIError`` is printed and the request retried indefinitely,
    waiting between attempts (1s doubling up to 60s) so a failing API is not
    hammered in a tight loop.

    Args:
        text: The text to embed.
        model: Name of the embedding model.
    """
    client = OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY"),
    )
    text = text.replace("\n", " ")
    delay = 1.0
    while True:
        try:
            return client.embeddings.create(input=[text], model=model).data[0].embedding
        except openai.APIError as error:
            print(error)
            time.sleep(delay)
            delay = min(delay * 2, 60.0)